/*
 * Lincheck
 *
 * Copyright (C) 2019 - 2025 JetBrains s.r.o.
 *
 * This Source Code Form is subject to the terms of the
 * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed
 * with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

package org.jetbrains.lincheck.trace

import org.jetbrains.lincheck.descriptors.toType
import org.jetbrains.lincheck.trace.CompressingPostprocessor.compressAccessPairs

/**
 * Algorithm that modifies trace points so that the trace looks better when viewed in the plugin.
 *
 * This interface is meant for cases where the structure of the trace tree may change,
 * which `TRPointPrinter`s cannot handle (see `TRTracePointPrinters.kt`).
 */
interface TracePostprocessor {
    /**
     * Takes a [tracePoint] and returns a possibly modified one. The returned trace point is the one
     * later used to load the next children, so replacing it changes the structure of the tree.
     *
     * @param reader can be used to read the children of [tracePoint].
     * @param tracePoint the trace point to postprocess.
     * @return the trace point to use in place of [tracePoint] (may be [tracePoint] itself).
     */
    fun postprocess(reader: LazyTraceReader, tracePoint: TRTracePoint): TRTracePoint
}

/**
 * [TracePostprocessor] that collapses compiler-generated call chains into the single
 * trace point the user actually cares about.
 *
 * The following rewrites are applied in order, each one to the result of the previous:
 *  1. `fun$default(...)` wrappers of functions with default arguments (see [compressDefaultPairs]);
 *  2. `access$fun(...)` synthetic accessors of otherwise inaccessible methods (see [compressAccessPairs]);
 *  3. `access$get...` / `access$set...` synthetic field accessors (see [compressSyntheticFieldAccess]);
 *  4. calls to auto-generated property getters and setters (see [compressAutoGeneratedFieldAccess]).
 *
 * Each rewrite keeps the code location of the outer (replaced) trace point, so that clicking
 * the collapsed node in the plugin sidebar navigates to the original call site.
 */
object CompressingPostprocessor : TracePostprocessor {

    override fun postprocess(reader: LazyTraceReader, tracePoint: TRTracePoint): TRTracePoint =
        tracePoint
            .compressDefaultPairs(reader)
            .compressAccessPairs(reader)
            .compressSyntheticFieldAccess(reader)
            .compressAutoGeneratedFieldAccess(reader)

    /**
     * Compresses `fun$default(...)` calls.
     *
     * Kotlin functions with default values are represented as two nested calls in the stack trace.
     *
     * For example:
     *
     * ```
     * A.callMe$default(A#1, 3, null, 2, null) at A.operation(A.kt:23)
     *   A.callMe(3, "Hey") at A.callMe$default(A.kt:27)
     * ```
     *
     * will be collapsed into:
     *
     * ```
     * A.callMe(3, "Hey") at A.operation(A.kt:23)
     * ```
     */
    private fun TRTracePoint.compressDefaultPairs(reader: LazyTraceReader): TRTracePoint =
        compressCallPair(reader) { parent, child -> isDefaultPair(parent.methodName, child.methodName) }

    /**
     * Compresses `access$fun` calls.
     *
     * The `access$` methods are generated by the Kotlin compiler to access otherwise inaccessible members
     * (e.g., private ones) from lambdas, inner classes, etc.
     *
     * For example:
     *
     * ```
     * A.access$callMe() at A.operation(A.kt:N)
     *  A.callMe() at A.access$callMe(A.kt:N)
     * ```
     *
     * will be collapsed into:
     *
     * ```
     * A.callMe() at A.operation(A.kt:N)
     * ```
     */
    private fun TRTracePoint.compressAccessPairs(reader: LazyTraceReader): TRTracePoint =
        compressCallPair(reader) { parent, child -> isAccessPair(parent.methodName, child.methodName) }

    /**
     * Shared implementation of [compressDefaultPairs] and [compressAccessPairs].
     *
     * Merges this trace point with its child (via [combineCallNodes]) when all of the following hold:
     *  - this trace point is a method call with exactly one child;
     *  - that child is a method call on the same class;
     *  - [isPair] accepts the (parent, child) pair.
     *
     * Otherwise, returns this trace point unchanged.
     */
    private inline fun TRTracePoint.compressCallPair(
        reader: LazyTraceReader,
        isPair: (parent: TRMethodCallTracePoint, child: TRMethodCallTracePoint) -> Boolean
    ): TRTracePoint {
        if (this !is TRMethodCallTracePoint) return this
        if (events.size != 1) return this
        val singleChild = loadChild(reader, tracePoint = this, 0) as? TRMethodCallTracePoint ?: return this
        if (className != singleChild.className || !isPair(this, singleChild)) return this
        return combineCallNodes(this, singleChild)
    }

    /**
     * Compresses synthetic field access methods (`access$get` and `access$set`).
     *
     * These methods are generated by the Kotlin compiler when a lambda or inner class accesses a private field.
     * This function removes the synthetic access method and keeps only the actual field access event.
     *
     * For the `get` version of the method:
     *
     * ```
     * A.access$get_field(A#1) at A.operation(A.kt:N)
     *   this ➜ A#1 at A.access$get_field(A.kt:N)
     *   A#1.field ➜ value at A.access$get_field(A.kt:N)
     * ```
     *
     * will be collapsed into:
     *
     * ```
     * A#1.field ➜ value at A.operation(A.kt:N)
     * ```
     *
     * For the `set` version of the method:
     *
     * ```
     * A.access$set_field(A#1) at A.operation(A.kt:N)
     *   this ➜ A#1 at A.access$set_field(A.kt:N)
     *   <set-?> ➜ value at A.access$set_field(A.kt:N)
     *   A#1.field = value at A.access$set_field(A.kt:N)
     * ```
     *
     * will be collapsed into:
     *
     * ```
     * A#1.field = value at A.operation(A.kt:N)
     * ```
     *
     * This is different from `access$fun`, which is addressed in [compressAccessPairs].
     */
    private fun TRTracePoint.compressSyntheticFieldAccess(reader: LazyTraceReader): TRTracePoint {
        if (this !is TRMethodCallTracePoint) return this
        if (!isSyntheticFieldAccess(methodName)) return this

        // The get and set method versions have 2 and 3 children, respectively, and the actual
        // field access node is always the last one. Requiring the exact count for the detected
        // kind prevents picking a middle child when a getter unexpectedly has 3 children.
        val expectedChildCount = if (isSyntheticGetFieldAccess(methodName)) 2 else 3
        if (events.size != expectedChildCount) return this
        val actualNode = loadChild(reader, tracePoint = this, events.lastIndex) ?: return this

        // For proper clicking in the trace recorder plugin sidebar to work,
        // code location must be rewritten with the one from the parent trace point
        return when (actualNode) {
            is TRReadTracePoint -> TRReadTracePoint(actualNode.threadId, codeLocationId, actualNode.fieldId,
                                                    actualNode.obj, actualNode.value, actualNode.eventId)
            is TRWriteTracePoint -> TRWriteTracePoint(actualNode.threadId, codeLocationId, actualNode.fieldId,
                                                      actualNode.obj, actualNode.value, actualNode.eventId)
            else -> this
        }
    }

    /**
     * Compresses `.[get|set]FieldName` calls to default implementations of field getters and setters.
     *
     * For the `get` version of the method:
     *
     * ```
     * A.getFieldName() at A.operation(A.kt:N)
     *   A.fieldName ➜ value at A.getFieldName(A.kt:K)
     * ```
     *
     * will be collapsed into:
     *
     * ```
     * A.fieldName ➜ value at A.operation(A.kt:N)
     * ```
     *
     * For the `set` version of the method:
     *
     * ```
     * A.setFieldName(value) at A.operation(A.kt:N)
     *  <set-?> ➜ value at A.setFieldName(A.kt:N)
     *  A.fieldName = value at A.setFieldName(A.kt:N)
     * ```
     *
     * will be collapsed into:
     *
     * ```
     * A.fieldName = value at A.operation(A.kt:N)
     * ```
     *
     * *Note*: if a user defines a custom getter or setter, then it will not be compressed.
     */
    private fun TRTracePoint.compressAutoGeneratedFieldAccess(reader: LazyTraceReader): TRTracePoint {
        if (this !is TRMethodCallTracePoint) return this
        // Auto-generated getter has a single child and setter has two of them.
        if (events.size != 1 && events.size != 2) return this
        // Note: when constructing new trace points (TRReadTracePoint or TRWriteTracePoint),
        // for proper clicking in the trace recorder plugin sidebar to work,
        // code location must be rewritten with the one from the parent trace point

        val child1 = loadChild(reader, tracePoint = this, 0) ?: return this
        if (events.size == 1) {
            val isAutoGeneratedFieldGetter = (
                child1 is TRReadTracePoint && className == child1.className &&
                // method signature is equal to the getter
                child1.value != null && // if `value` was not loaded successfully from disk, then do not compress
                parameters.isEmpty() && methodDescriptor.returnType == child1.value.className.toType() &&
                isAutoGeneratedFieldGetterName(methodName, child1.name)
            )

            if (isAutoGeneratedFieldGetter) {
                return TRReadTracePoint(child1.threadId, codeLocationId, child1.fieldId,
                                        child1.obj, child1.value, child1.eventId)
            }
        }
        else /* events.size == 2 */ {
            val child2 = loadChild(reader, tracePoint = this, 1) ?: return this
            val autoGeneratedSetterParamName = "<set-?>"
            // there is no need to check for the method signature of the setter,
            // because user code cannot contain variables named `<set-?>`
            val isAutoGeneratedFieldSetter = (
                child1 is TRLocalVariableTracePoint && child1.name == autoGeneratedSetterParamName &&
                child2 is TRWriteTracePoint && className == child2.className &&
                isAutoGeneratedFieldSetterName(methodName, child2.name)
            )

            if (isAutoGeneratedFieldSetter) return TRWriteTracePoint(child2.threadId, codeLocationId, child2.fieldId,
                                                                     child2.obj, child2.value, child2.eventId)
        }

        return this
    }

    /**
     * Returns the child of [tracePoint] at [childIdx], or `null` if [childIdx] is out of range.
     *
     * An entry already present in [TRMethodCallTracePoint.events] is returned as is; if that entry is
     * `null` (presumably meaning "not loaded yet" — confirm), the child is read through [reader]
     * via `getChildAndRestorePosition`.
     */
    private fun loadChild(reader: LazyTraceReader, tracePoint: TRMethodCallTracePoint, childIdx: Int): TRTracePoint? {
        if (childIdx !in 0..tracePoint.events.lastIndex) return null
        return tracePoint.events[childIdx] ?: reader.getChildAndRestorePosition(tracePoint, childIdx)
    }

    /**
     * Merges two nested method-call trace points into a single one.
     *
     * The resulting node takes everything (thread, method, receiver, parameters, event id, result,
     * exception and children) from [child], except for the code location, which is taken from [parent].
     */
    private fun combineCallNodes(parent: TRMethodCallTracePoint, child: TRMethodCallTracePoint): TRMethodCallTracePoint {
        val newNode = TRMethodCallTracePoint(
            child.threadId,
            // Code location is taken from the parent trace point and not from the child
            // because in the trace sidebar of the plugin for proper clicking to work
            // the code location reference must be the same as for the parent
            parent.codeLocationId,
            child.methodId,
            child.obj,
            child.parameters,
            child.eventId
        )
        newNode.result = child.result
        newNode.exceptionClassName = child.exceptionClassName
        newNode.replaceChildren(child)
        return newNode
    }
}